# 模型训练
import torch
from MyData import MyDataset
from torch.utils.data import DataLoader
from net import Model
from transformers import BertTokenizer, AutoTokenizer
from torch.optim import AdamW

#定义设备信息
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#定义训练的轮次
EPOCH= 30000

token = BertTokenizer.from_pretrained(
    r"D:\AIModels\llm\bert-base-chinese\models--bert-base-chinese\snapshots\c30a6ed22ab4564dc1e3b2ecbf6e766b0611a33f"
    )

def collate_fn(data): # 对数据进行编码
    """
    对数据进行编码
    :param data:
    :return:
    """
    sents = [i[0]for i in data]
    label = [i[1] for i in data]
    #编码
    data = token.batch_encode_plus(
        batch_text_or_text_pairs=sents, #输入的文本
        truncation=True, #当句子长度大于max_length时，截断
        max_length=500, #设定最大序列长度
        padding="max_length", #填充到最大的序列长度， 不足的部分用0填充
        return_tensors="pt", #返回的类型, 如果为None，返回list， 如果为tf，返回tf张量， 如果为pt，返回pytorch张量， 如果为np，返回numpy数组
        return_length=True #返回length长度
    )

    input_ids = data["input_ids"] # 编码后的词
    attention_mask = data["attention_mask"] # 有有效值的部分为1，其他部分为0
    token_type_ids = data["token_type_ids"] # 第一个句子和特殊符号的位置是0，第二个句子的位置1（）只针对于上下文编码
    # special_tokens_mask 特殊符号的位置是1，其他位置是0
    labels = torch.LongTensor(label)  #将标签转换为张量

    return input_ids,attention_mask,token_type_ids,labels


#创建数据集
train_dataset = MyDataset("train")
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=100,
    shuffle=True,
    #舍弃最后一个批次的数据，防止形状出错
    drop_last=True,
    #对加载进来的数据进行编码
    collate_fn=collate_fn
)

val_dataset = MyDataset("validation")
val_loader = DataLoader(
    dataset=val_dataset,
    batch_size=100,
    shuffle=True,
    #舍弃最后一个批次的数据，防止形状出错
    drop_last=True,
    #对加载进来的数据进行编码
    collate_fn=collate_fn
)

if __name__ == '__main__':
    #开始训练
    print(DEVICE)
    model = Model().to(DEVICE)
    #定义优化器
    optimizer = AdamW(model.parameters())  #AdamW优化器，用于优化模型参数
    #定义损失函数
    loss_func = torch.nn.CrossEntropyLoss()  #交叉熵损失函数

    #初始化最佳验证准确率
    best_val_acc = 0.0

    #开始训练
    model.train()
    for epoch in range(EPOCH):
        for i,(input_ids,attention_mask,token_type_ids,labels) in enumerate(train_loader):
            #将数据存放到DEVICE上
            input_ids, attention_mask, token_type_ids, labels = \
                input_ids.to(DEVICE),attention_mask.to(DEVICE),token_type_ids.to(DEVICE),labels.to(DEVICE)
            #前向计算（将数据输入模型，得到输出）, 
            out = model(input_ids=input_ids,attention_mask=attention_mask,token_type_ids=token_type_ids)
            #根据输出，计算损失
            loss = loss_func(out, labels)
            #根据损失，优化参数
            optimizer.zero_grad() # 清空梯度, 防止梯度累积
            loss.backward() # 反向传播, 计算梯度
            optimizer.step() # 更新参数, 根据梯度更新参数

            #每隔5个批次输出训练信息
            if i%5==0:
                out = out.argmax(dim=1)
                acc = (out==labels).sum().item()/len(labels)
                print(f"epoch:{epoch},i:{i},loss:{loss.item()},acc:{acc}")
        #验证模型(判断是否过拟合)
        #设置为评估模式
        model.eval()
        #不需要模型参与训练
        with torch.no_grad(): # grad: gradient 斜率，梯度；
            val_acc = 0.0
            val_loss = 0.0
            for i, (input_ids, attention_mask, token_type_ids, labels) in enumerate(val_loader):
                # 将数据存放到DEVICE上
                input_ids, attention_mask, token_type_ids, labels = \
                    input_ids.to(DEVICE), attention_mask.to(DEVICE), token_type_ids.to(DEVICE), labels.to(DEVICE)
                # 前向计算（将数据输入模型，得到输出）
                out = model(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
                # 根据输出，计算损失
                val_loss += loss_func(out, labels)
                out = out.argmax(dim=1) # 取概率最大的类别作为预测类别， dim: 0为横向，1为纵向
                val_acc+=(out==labels).sum().item() #
            val_loss /= len(val_loader)
            val_acc /= len(val_loader)
            print(f"验证集：loss:{val_loss},acc：{val_acc}")

            #根据验证准确率保存最优参数
            if val_acc > best_val_acc:
                best_val_acc = val_acc
                torch.save(model.state_dict(),"params/best_bert.pth")
                print(f"Epoch:{epoch}:保存最优参数：acc:{best_val_acc}")

        #保存最后一轮参数
        torch.save(model.state_dict(),f"params/last_bert.pth")
        print(epoch,f"Epcot：{epoch}最后一轮参数保存成功！")